{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import polars as pl\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "# CHANGE: set these paths for your environment. Each must end with a path separator\n",
    "# (or be '' for the current working directory), because later cells build file paths\n",
    "# by plain string concatenation, e.g. path_data + 'accuracy_figure.csv'.\n",
    "path_data = '../data/' # input directory with the aggregate data (included in this package)\n",
    "path_charts = '' # output directory for saved charts; '' writes to the current working directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the accuracy-by-cutoff table: the CSV's first column becomes the index,\n",
    "# cast to float so it serves as a numeric x-axis when plotting.\n",
    "df_plot = (\n",
    "    pd.read_csv(path_data + 'accuracy_figure.csv', index_col=0)\n",
    "    .pipe(lambda frame: frame.set_axis(frame.index.astype(float), axis=0))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot accuracy and match-rate curves against the cutoff.\n",
    "# Colour scheme: blue = accuracy, green = percentage matched; dashed = unweighted, solid = weighted.\n",
    "CHOSEN_CUTOFF = 0.7  # cutoff at which weighted accuracy is ~0.9 (marked with a red vertical line)\n",
    "\n",
    "ax = df_plot.plot(\n",
    "    grid=True,\n",
    "    style={\n",
    "        'accuracy': 'b--',\n",
    "        'weighted accuracy': 'b-',\n",
    "        'percentage matched': 'g--',\n",
    "        'weighted percentage matched': 'g-',\n",
    "    },\n",
    ")\n",
    "ax.set_xlabel('Cutoff')\n",
    "ax.set_ylim(0, 1.1)  # vertical axis from 0 to 1.1: headroom above 1, assuming the rates are fractions\n",
    "ax.axvline(CHOSEN_CUTOFF, color='red', linestyle='--')  # mark the chosen cutoff\n",
    "ax.legend(loc='upper left', bbox_to_anchor=(0, 0.3))  # legend's top-left corner at 30% of axes height\n",
    "\n",
    "ax.get_figure().savefig(path_charts + 'accuracy_and_perc_matched.eps', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the cutoff whose weighted accuracy is closest to the target (0.9).\n",
    "# Built from df_plot, loaded earlier in this notebook, so the cell survives Restart & Run All:\n",
    "# its index is the cutoff and its 'weighted accuracy' column holds the weighted accuracy.\n",
    "TARGET_WEIGHTED_ACCURACY = 0.9\n",
    "distance_col = f'distance_to_{TARGET_WEIGHTED_ACCURACY}'  # 'distance_to_0.9' for the default target\n",
    "\n",
    "(\n",
    "    pl.DataFrame(\n",
    "        {\n",
    "            'weighted_accuracy': df_plot['weighted accuracy'].to_numpy(),\n",
    "            'cutoff': df_plot.index.to_numpy(),\n",
    "        }\n",
    "    )\n",
    "    .with_columns(\n",
    "        (pl.col('weighted_accuracy') - TARGET_WEIGHTED_ACCURACY).abs().alias(distance_col)\n",
    "    )\n",
    "    .sort(distance_col)\n",
    "    .head(4)\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.12 ('env_indeed2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "26da711ab583a058e13fb43990c4acfd219f633b35c618763404a0fc5624b2b9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
